{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RMSProp算法\n",
    ":label:`sec_rmsprop`\n",
    "\n",
    " :numref:`sec_adagrad`中的关键问题之一，是学习率按预定时间表$\\mathcal{O}(t^{-\\frac{1}{2}})$显著降低。\n",
    "虽然这通常适用于凸问题，但对于深度学习中遇到的非凸问题，可能并不理想。\n",
    "但是，作为一个预处理器，Adagrad算法按坐标顺序的适应性是非常可取的。\n",
    "\n",
    " :cite:`Tieleman.Hinton.2012`建议以RMSProp算法作为将速率调度与坐标自适应学习率分离的简单修复方法。\n",
    "问题在于，Adagrad算法将梯度$\\mathbf{g}_t$的平方累加成状态矢量$\\mathbf{s}_t = \\mathbf{s}_{t-1} + \\mathbf{g}_t^2$。\n",
    "因此，由于缺乏规范化，没有约束力，$\\mathbf{s}_t$持续增长，几乎上是在算法收敛时呈线性递增。\n",
    "\n",
    "解决此问题的一种方法是使用$\\mathbf{s}_t / t$。\n",
    "对于$\\mathbf{g}_t$的合理分布来说，它将收敛。\n",
    "遗憾的是，限制行为生效可能需要很长时间，因为该流程记住了价值的完整轨迹。\n",
    "另一种方法是按动量法中的方式使用泄漏平均值，即$\\mathbf{s}_t \\leftarrow \\gamma \\mathbf{s}_{t-1} + (1-\\gamma) \\mathbf{g}_t^2$，其中参数$\\gamma > 0$。\n",
    "保持其他部所有分不变就产生了RMSProp算法。\n",
    "\n",
    "## 算法\n",
    "\n",
    "让我们详细写出这些方程式。\n",
    "\n",
    "$$\\begin{aligned}\n",
    "    \\mathbf{s}_t & \\leftarrow \\gamma \\mathbf{s}_{t-1} + (1 - \\gamma) \\mathbf{g}_t^2, \\\\\n",
    "    \\mathbf{x}_t & \\leftarrow \\mathbf{x}_{t-1} - \\frac{\\eta}{\\sqrt{\\mathbf{s}_t + \\epsilon}} \\odot \\mathbf{g}_t.\n",
    "\\end{aligned}$$\n",
    "\n",
    "常数$\\epsilon > 0$通常设置为$10^{-6}$，以确保我们不会因除以零或步长过大而受到影响。\n",
    "鉴于这种扩展，我们现在可以自由控制学习率$\\eta$，而不考虑基于每个坐标应用的缩放。\n",
    "就泄漏平均值而言，我们可以采用与之前在动量法中适用的相同推理。\n",
    "扩展$\\mathbf{s}_t$定义可获得\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "\\mathbf{s}_t & = (1 - \\gamma) \\mathbf{g}_t^2 + \\gamma \\mathbf{s}_{t-1} \\\\\n",
    "& = (1 - \\gamma) \\left(\\mathbf{g}_t^2 + \\gamma \\mathbf{g}_{t-1}^2 + \\gamma^2 \\mathbf{g}_{t-2} + \\ldots, \\right).\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "同之前在 :numref:`sec_momentum`小节一样，我们使用$1 + \\gamma + \\gamma^2 + \\ldots, = \\frac{1}{1-\\gamma}$。\n",
    "因此，权重总和标准化为$1$且观测值的半衰期为$\\gamma^{-1}$。\n",
    "让我们图像化各种数值的$\\gamma$在过去40个时间步长的权重。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports\n",
    "%load ../utils/plot-utils\n",
    "%load ../utils/Functions.java\n",
    "%load ../utils/GradDescUtils.java\n",
    "%load ../utils/Accumulator.java\n",
    "%load ../utils/StopWatch.java\n",
    "%load ../utils/Training.java\n",
    "%load ../utils/TrainingChapter11.java"
   ]
  },
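  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before plotting the weights, here is the numeric sketch promised above (a toy illustration, assuming a constant gradient $g$ purely for simplicity): Adagrad's accumulator grows linearly without bound, while the leaky average converges to $g^2$.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Toy comparison with an assumed constant gradient g:\n",
    "// Adagrad accumulates s_t = s_{t-1} + g^2 and grows linearly,\n",
    "// while the leaky average s_t = gamma * s_{t-1} + (1 - gamma) * g^2\n",
    "// converges to g^2.\n",
    "float gToy = 1.0f;\n",
    "float gammaToy = 0.9f;\n",
    "float sAdagrad = 0f;\n",
    "float sLeaky = 0f;\n",
    "for (int t = 1; t <= 100; t++) {\n",
    "    sAdagrad += gToy * gToy;\n",
    "    sLeaky = gammaToy * sLeaky + (1 - gammaToy) * gToy * gToy;\n",
    "    if (t % 25 == 0) {\n",
    "        System.out.printf(\"t=%3d  Adagrad s=%6.1f  leaky s=%.4f%n\", t, sAdagrad, sLeaky);\n",
    "    }\n",
    "}"
   ]
  },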
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDManager manager = NDManager.newBaseManager();\n",
    "\n",
    "float[] gammas = new float[]{0.95f, 0.9f, 0.8f, 0.7f};\n",
    "\n",
    "NDArray timesND = manager.arange(40f);\n",
    "float[] times = timesND.toFloatArray();\n",
    "display(GradDescUtils.plotGammas(times, gammas, 600, 400));"
   ]
  },
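  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a complement to the plot, a short numeric check (a minimal sketch under the same choices of $\\gamma$) confirms the normalization claim: the weights $(1-\\gamma)\\gamma^i$ sum to $1$ in the limit, and the most recent $\\approx 1/(1-\\gamma)$ observations already carry the bulk of the mass.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Check that the weights (1 - gamma) * gamma^i sum to (nearly) 1 over a\n",
    "// long horizon and that roughly the last 1 / (1 - gamma) observations\n",
    "// carry the bulk of the mass.\n",
    "for (float gammaVal : new float[]{0.95f, 0.9f, 0.8f, 0.7f}) {\n",
    "    int horizon = Math.round(1 / (1 - gammaVal));\n",
    "    float total = 0f;\n",
    "    float recent = 0f;\n",
    "    for (int i = 0; i < 1000; i++) {\n",
    "        float w = (1 - gammaVal) * (float) Math.pow(gammaVal, i);\n",
    "        total += w;\n",
    "        if (i < horizon) {\n",
    "            recent += w;\n",
    "        }\n",
    "    }\n",
    "    System.out.printf(\"gamma=%.2f  total weight=%.4f  weight in last %d steps=%.3f%n\",\n",
    "                      gammaVal, total, horizon, recent);\n",
    "}"
   ]
  },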
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 从零开始实现\n",
    "\n",
    "和之前一样，我们使用二次函数$f(\\mathbf{x})=0.1x_1^2+2x_2^2$来观察RMSProp算法的轨迹。\n",
    "回想在 :numref:`sec_adagrad`一节中，当我们使用学习率为0.4的Adagrad算法时，变量在算法的后期阶段移动非常缓慢，因为学习率衰减太快。\n",
    "RMSProp算法中不会发生这种情况，因为$\\eta$是单独控制的。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "float eta = 0.4f;\n",
    "float gamma = 0.9f;\n",
    "\n",
    "Function<Float[], Float[]> rmsProp2d = (state) -> {\n",
    "    Float x1 = state[0], x2 = state[1], s1 = state[2], s2 = state[3];\n",
    "    float g1 = 0.2f * x1;\n",
    "    float g2 = 4 * x2;\n",
    "    float eps = (float) 1e-6;\n",
    "    s1 = gamma * s1 + (1 - gamma) * g1 * g1;\n",
    "    s2 = gamma * s2 + (1 - gamma) * g2 * g2;\n",
    "    x1 -= eta / (float) Math.sqrt(s1 + eps) * g1;\n",
    "    x2 -= eta / (float) Math.sqrt(s2 + eps) * g2;\n",
    "    return new Float[]{x1, x2, s1, s2};\n",
    "};\n",
    "\n",
    "BiFunction<Float, Float, Float> f2d = (x1, x2) -> {\n",
    "    return 0.1f * x1 * x1 + 2 * x2 * x2;\n",
    "};\n",
    "\n",
    "GradDescUtils.showTrace2d(f2d, GradDescUtils.train2d(rmsProp2d, 20));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![RmsProp Gradient Descent 2D.](https://d2l-java-resources.s3.amazonaws.com/img/chapter_optim-rmsprop-gd2d.svg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来，我们在深度网络中实现RMSProp算法。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDList initRmsPropStates(int featureDimension) {\n",
    "    NDManager manager = NDManager.newBaseManager();\n",
    "    NDArray sW = manager.zeros(new Shape(featureDimension, 1));\n",
    "    NDArray sB = manager.zeros(new Shape(1));\n",
    "    return new NDList(sW, sB);\n",
    "}\n",
    "\n",
    "public class Optimization {\n",
    "    public static void rmsProp(NDList params, NDList states, Map<String, Float> hyperparams) {\n",
    "        float gamma = hyperparams.get(\"gamma\");\n",
    "        float eps = (float) 1e-6;\n",
    "        for (int i = 0; i < params.size(); i++) {\n",
    "            NDArray param = params.get(i);\n",
    "            NDArray state = states.get(i);\n",
    "            // Update parameter and state\n",
    "            // state = gamma * state + (1 - gamma) * param.gradient^(1/2)\n",
    "            state.muli(gamma).addi(param.getGradient().square().mul(1 - gamma));\n",
    "            // param -= lr * param.gradient / sqrt(s + eps)\n",
    "            param.subi(param.getGradient().mul(hyperparams.get(\"lr\")).div(state.add(eps).sqrt()));\n",
    "        }\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们将初始学习率设置为0.01，加权项$\\gamma$设置为0.9。\n",
    "也就是说，$\\mathbf{s}$累加了过去的$1/(1-\\gamma) = 10$次平方梯度观测值的平均值。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "AirfoilRandomAccess airfoil = TrainingChapter11.getDataCh11(10, 1500);\n",
    "\n",
    "public TrainingChapter11.LossTime trainRmsProp(float lr, float gamma, int numEpochs) \n",
    "                    throws IOException, TranslateException {\n",
    "    int featureDimension = airfoil.getColumnNames().size();\n",
    "    Map<String, Float> hyperparams = new HashMap<>();\n",
    "    hyperparams.put(\"lr\", lr);\n",
    "    hyperparams.put(\"gamma\", gamma);\n",
    "    return TrainingChapter11.trainCh11(Optimization::rmsProp, \n",
    "                                       initRmsPropStates(featureDimension), \n",
    "                                       hyperparams, airfoil, \n",
    "                                       featureDimension, numEpochs);\n",
    "}\n",
    "\n",
    "trainRmsProp(0.01f, 0.9f, 2);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 简洁实现\n",
    "\n",
    "我们可直接使用DJL中提供的RMSProp算法来训练模型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Tracker lrt = Tracker.fixed(0.01f);\n",
    "Optimizer rmsProp = Optimizer.rmsprop().optLearningRateTracker(lrt).optRho(0.9f).build();\n",
    "\n",
    "TrainingChapter11.trainConciseCh11(rmsProp, airfoil, 2);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 小结\n",
    "\n",
    "* RMSProp算法与Adagrad算法非常相似，因为两者都使用梯度的平方来缩放系数。\n",
    "* RMSProp算法与动量法都使用泄漏平均值。但是，RMSProp算法使用该技术来调整按系数顺序的预处理器。\n",
    "* 在实验中，学习率需要由实验者调度。\n",
    "* 系数$\\gamma$决定了在调整每坐标比例时历史记录的时长。\n",
    "\n",
    "## 练习\n",
    "\n",
    "1. 如果我们设置$\\gamma = 1$，实验会发生什么？为什么？\n",
    "1. 旋转优化问题以最小化$f(\\mathbf{x}) = 0.1 (x_1 + x_2)^2 + 2 (x_1 - x_2)^2$。收敛会发生什么？\n",
    "1. 试试在真正的机器学习问题上应用RMSProp算法会发生什么，例如在Fashion-MNIST上的训练。试验不同的取值来调整学习率。\n",
    "1. 随着优化的进展，需要调整$\\gamma$吗？RMSProp算法对此有多敏感？\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
